{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Online RL over a Multi-Module DSPy Program\n",
    "\n",
    "WARNING: This feature is new and extremely EXPERIMENTAL. Unlike almost everything else in DSPy, it's currently in pure proof of concept and development mode, but we release it to encourage community involvement.\n",
    "\n",
    "In this tutorial, we optimize the LM weights of [PAPILLON](https://dspy.ai/tutorials/papillon/) with `ArborGRPO`, a generalization of the popular GRPO online RL algorithm of LLMs to sophisticated multi-module LM programs.\n",
    "\n",
    "PAPILLON is a system for privacy-preserving delegation, where we will teach a tiny model (1.5B parameters) to use an \"untrusted\" external LLM, which is more powerful but may save your private data, to balance high-quality and private chat.\n",
    "\n",
    "For this tutorial, you will also need [DSPy's Arbor RL framework](https://github.com/Ziems/arbor) which you can install with:\n",
    "```bash\n",
 pip install -U arbor-ai">
    "pip install -U arbor-ai\n",
    "```\n",
    "\n",
    "You may also have to install DSPy from the main branch:\n",
    "```bash\n",
 pip install -U git+https">
    "pip install -U git+https://github.com/stanfordnlp/dspy.git@main\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dspy\n",
    "import arbor\n",
    "from arbor import ArborGRPO, ArborProvider\n",
    "arbor_server_info = arbor.init() # Initialize the Arbor server in the background\n",
    "\n",
    "port = 7453\n",
    "local_lm_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "local_lm = dspy.LM(\n",
    "    model=f\"openai/arbor:{local_lm_name}\",\n",
    "    provider=ArborProvider(),\n",
    "    api_base=arbor_server_info[\"base_url\"],\n",
    "    # Arbor checks to make sure these match the training config\n",
    "    temperature=1.0,\n",
    "    top_p=1.0,\n",
    "    top_k=-1,\n",
    "    repetition_penalty=1.0,\n",
    "    max_tokens=2048,\n",
    ")\n",
    "\n",
    "dspy.configure(lm=local_lm)\n",
    "\n",
    "openai_lm = dspy.LM(model=\"openai/gpt-4.1-mini\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CraftRedactedRequest(dspy.Signature):\n",
    "    \"\"\"\n",
    "    Given a private user query, create a privacy-preserving request for a powerful external LLM.\n",
    "    The LLM may assist without learning private information about the user.\n",
    "    \"\"\"\n",
    "\n",
    "    user_query = dspy.InputField()\n",
    "    llm_request = dspy.OutputField()\n",
    "\n",
    "\n",
    "class RespondToQuery(dspy.Signature):\n",
    "    \"\"\"\n",
    "    Respond to a user query.\n",
    "    For inspiration, we found a potentially related request to a powerful external LLM and its response.\n",
    "    \"\"\"\n",
    "\n",
    "    related_llm_request = dspy.InputField()\n",
    "    related_llm_response = dspy.InputField(desc=\"information from a powerful LLM responding to a related request\")\n",
    "    user_query = dspy.InputField(desc=\"the user's request you need to fulfill\")\n",
    "    response = dspy.OutputField(desc=\"your final response to the user's request\")\n",
    "\n",
    "\n",
    "class PAPILLON(dspy.Module):\n",
    "    def __init__(self, untrusted_model):\n",
    "        self.craft_redacted_request = dspy.ChainOfThought(CraftRedactedRequest)\n",
    "        self.respond_to_query = dspy.Predict(RespondToQuery)\n",
    "        self.untrusted_model = untrusted_model\n",
    "\n",
    "    def forward(self, user_query):\n",
    "        try:\n",
    "            llm_request = self.craft_redacted_request(user_query=user_query).llm_request\n",
    "            llm_response = self.untrusted_model(llm_request)[0]\n",
    "            response = self.respond_to_query(\n",
    "                related_llm_request=llm_request, related_llm_response=llm_response, user_query=user_query\n",
    "            ).response\n",
    "        except Exception:\n",
    "            return dspy.Prediction(llm_request=\"\", llm_response=\"\", response=\"\")\n",
    "\n",
    "        return dspy.Prediction(llm_request=llm_request, llm_response=llm_response, response=response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "pupa_tnb = load_dataset(\"Columbia-NLP/PUPA\", \"pupa_tnb\")\n",
    "pupa_new = load_dataset(\"Columbia-NLP/PUPA\", \"pupa_new\")\n",
    "\n",
    "examples = [\n",
    "    dspy.Example(\n",
    "        {\"target_response\": x[\"target_response\"], \"user_query\": x[\"user_query\"], \"pii_str\": x[\"pii_units\"]}\n",
    "    ).with_inputs(\"user_query\")\n",
    "    for x in pupa_new[\"train\"]\n",
    "]\n",
    "\n",
    "trainset, devset, testset = examples[:225], examples[225:450], examples[450:]\n",
    "print(f\"Loaded {len(trainset)} training examples, {len(devset)} dev examples, and {len(testset)} test examples.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define metrics for success in this task\n",
    "\n",
    "What does it mean for a PAPILLON system to be successful?\n",
    "\n",
    "1. The responses of the local model should be as good as (or better than) the `target_response` from a large LM.\n",
    "2. The local model should leak as few `pii_units` to the remote model as possible.\n",
    "\n",
    "For benchmarking, we will judge both of these using our `openai_lm` and the annotation in PUPA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JudgeQuality(dspy.Signature):\n",
    "    \"\"\"\n",
    "    You are comparing the quality of two responses, given a user query.\n",
    "    Is response A at least as good as response B?\n",
    "    \"\"\"\n",
    "\n",
    "    user_query = dspy.InputField(desc=\"The user's request to be fulfilled.\")\n",
    "    response_A = dspy.InputField()\n",
    "    response_B = dspy.InputField()\n",
    "    judgment: bool = dspy.OutputField()\n",
    "\n",
    "\n",
    "class JudgeLeakage(dspy.Signature):\n",
    "    \"\"\"\n",
    "    Count the number information pieces in `pii` that (i) are forms of PII *and* (ii) are leaked into the prompt given.\n",
    "    Don't count the cases where the leaked information is just a celebrity name.\n",
    "    Don't count the cases where the `pii` information does not appear in the prompt.\n",
    "    \"\"\"\n",
    "\n",
    "    pii = dspy.InputField()\n",
    "    prompt = dspy.InputField()\n",
    "    num_pii_leaked: int = dspy.OutputField()\n",
    "\n",
    "\n",
    "class LLMJudge(dspy.Module):\n",
    "    def __init__(self):\n",
    "        self.quality_judge = dspy.ChainOfThought(JudgeQuality)\n",
    "        self.fact_checker = dspy.ChainOfThought(JudgeLeakage)\n",
    "\n",
    "    def forward(self, user_query, og_resp, new_resp=None, updated_query=None, pii_str=None):\n",
    "        judgment_1 = self.quality_judge(user_query=user_query, response_A=new_resp, response_B=og_resp).judgment\n",
    "        judgment_2 = self.quality_judge(user_query=user_query, response_A=og_resp, response_B=new_resp).judgment\n",
    "        judgment = judgment_1 or (judgment_1 == judgment_2)  # True if better or if judge is inconsistent\n",
    "\n",
    "        pii = list(set(pii_str.split(\"||\")))  # The pii_str field must be separated by `||`\n",
    "        pii_score = self.fact_checker(pii=pii, prompt=updated_query).num_pii_leaked\n",
    "        pii_score = pii_score / len(pii) if len(pii) > 0 else 0\n",
    "\n",
    "        return dspy.Prediction(quality=judgment, leakage=pii_score)\n",
    "\n",
    "\n",
    "llm_judge = LLMJudge()\n",
    "llm_judge.set_lm(openai_lm)"
   ]
  },
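  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the judge's arithmetic concrete, here is a minimal sketch in plain Python with no LM calls. The helpers `combine_judgments` and `leakage_fraction` are hypothetical names introduced only for illustration; they mirror the two expressions in `LLMJudge.forward` above, and the PII string is made up.\n",
    "\n",
    "```python\n",
    "def combine_judgments(new_vs_og: bool, og_vs_new: bool) -> bool:\n",
    "    # Same expression as in LLMJudge.forward: accept when the judge says the new\n",
    "    # response is at least as good, or when its two verdicts are equal.\n",
    "    return new_vs_og or (new_vs_og == og_vs_new)\n",
    "\n",
    "\n",
    "def leakage_fraction(num_pii_leaked: int, pii_str: str) -> float:\n",
    "    # Deduplicate the `||`-separated PII units, then normalize the leak count.\n",
    "    pii = set(pii_str.split(\"||\"))\n",
    "    return num_pii_leaked / len(pii) if len(pii) > 0 else 0\n",
    "\n",
    "\n",
    "# Enumerate all four verdict combinations.\n",
    "for new_vs_og in (True, False):\n",
    "    for og_vs_new in (True, False):\n",
    "        print(new_vs_og, og_vs_new, combine_judgments(new_vs_og, og_vs_new))\n",
    "\n",
    "# Made-up PII string with one duplicate: 3 unique units, 1 of them leaked.\n",
    "print(leakage_fraction(1, \"Alice||Paris||Alice||Acme Corp\"))\n",
    "```\n",
    "\n",
    "Only one combination yields `False`: the judge rejects the new response in the first comparison and accepts the original in the swapped one."
   ]
  },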
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With these judges, we can now define the metrics for optimization and for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(gold, pred, trace=None):\n",
    "    return llm_judge(\n",
    "        user_query=gold.user_query,\n",
    "        new_resp=pred.response,\n",
    "        og_resp=gold.target_response,\n",
    "        updated_query=pred.llm_request,\n",
    "        pii_str=gold.pii_str,\n",
    "    )\n",
    "\n",
    "\n",
    "def compute_quality(gold, pred, trace=None):\n",
    "    return compute_metrics(gold, pred, trace).quality\n",
    "\n",
    "\n",
    "def compute_leakage(gold, pred, trace=None):\n",
    "    return compute_metrics(gold, pred, trace).leakage\n",
    "\n",
    "\n",
    "def compute_overall_score(gold, pred, trace=None):\n",
    "    metrics = compute_metrics(gold, pred, trace)\n",
    "    overall_score = (metrics.quality + (1 - metrics.leakage)) / 2.0\n",
    "    return overall_score >= 1.0 if trace is not None else overall_score"
   ]
  },
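  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of the composite score, the sketch below applies the same formula as `compute_overall_score` to judge outputs passed in directly, so it runs without any LM. The function name `overall_score` and the input values are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "def overall_score(quality: bool, leakage: float, trace=None):\n",
    "    # Average of quality (True counts as 1) and privacy (1 - leakage).\n",
    "    score = (quality + (1 - leakage)) / 2.0\n",
    "    # With a trace, the metric turns into a strict pass/fail check.\n",
    "    return score >= 1.0 if trace is not None else score\n",
    "\n",
    "\n",
    "print(overall_score(True, 0.25))            # 0.875\n",
    "print(overall_score(True, 0.25, trace=[]))  # False\n",
    "print(overall_score(True, 0.0, trace=[]))   # True\n",
    "```\n",
    "\n",
    "When a trace is supplied, an example passes only if the response wins on quality and leaks nothing; without a trace, the float score gives partial credit."
   ]
  },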
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate zero-shot PAPILLON\n",
    "\n",
    "Let's now use the PUPA data and the judges above to evaluate the zero-shot version of our PAPILLON pipeline!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zeroshot = PAPILLON(untrusted_model=openai_lm)\n",
    "\n",
    "kwargs = dict(num_threads=16, display_progress=True, display_table=5, max_errors=100)\n",
    "evaluate = dspy.Evaluate(metric=compute_overall_score, devset=devset, **kwargs)\n",
    "evaluate(zeroshot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimize PAPILLON with `dspy.GRPO`\n",
    "\n",
    "Let's run the `dspy.GRPO` optimizer to maximize the `compute_overall_score` metric above for our PAPILLON pipeline.\n",
    "\n",
    "We ran this on 4xH100 GPUs for a couple of hours. But first, you'll need to set up Arbor (as above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "papillon = PAPILLON(untrusted_model=openai_lm)\n",
    "papillon.set_lm(local_lm)\n",
    "\n",
    "# NOTE: Training on 4 GPUs.\n",
    "train_kwargs = {\n",
    "    \"per_device_train_batch_size\": 8,\n",
    "    \"gradient_accumulation_steps\": 4,\n",
    "    \"temperature\": 1.0,\n",
    "    \"top_k\": -1,\n",
    "    \"top_p\": 1.0,\n",
    "    \"repetition_penalty\": 1.0,\n",
    "    \"beta\": 0.00,\n",
    "    \"learning_rate\": 1e-6,\n",
    "    \"gradient_checkpointing\": True,\n",
    "    \"bf16\": True,\n",
    "    \"lr_scheduler_type\": \"constant_with_warmup\",\n",
    "    \"loss_type\": \"dapo\",\n",
    "    \"max_steps\": 1000,\n",
    "    \"report_to\": \"wandb\",\n",
    "    \"log_completions\": True,\n",
    "    \"logging_steps\": 1,\n",
    "    \"max_prompt_length\": None,\n",
    "    \"max_completion_length\": None,\n",
    "    \"scale_rewards\": False,\n",
    "    \"max_grad_norm\": 1.0,\n",
    "    \"lora_config\": {\n",
    "        \"lora_alpha\": 16,\n",
    "        \"lora_dropout\": 0.05,\n",
    "        \"r\": 8,\n",
    "        \"target_modules\": [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
    "    },\n",
    "    \"num_training_gpus\": 3,\n",
    "    \"num_inference_gpus\": 1,\n",
    "    \"weight_decay\": 0.001,\n",
    "}\n",
    "\n",
    "compiler = ArborGRPO(\n",
    "    metric=compute_overall_score,\n",
    "    multitask=True,\n",
    "    num_dspy_examples_per_grpo_step=4,\n",
    "    num_samples_per_input=8,\n",
    "    exclude_demos=True,\n",
    "    num_train_steps=500,\n",
    "    num_threads=24,\n",
    "    use_train_as_val=False,\n",
    "    num_steps_for_val=10,\n",
    "    train_kwargs=train_kwargs,\n",
    "    report_train_scores=False,\n",
    ")\n",
    "\n",
    "optimized_papillon = compiler.compile(\n",
    "    student=papillon,\n",
    "    trainset=trainset,\n",
    "    valset=devset,\n",
    ")\n"
   ]
  },
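  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GRPO compares each sample with the other samples drawn for the same input. The sketch below illustrates that idea for a single group of 8 rollouts, assuming plain mean-centering without standard-deviation scaling (our reading of `scale_rewards=False`). It is not Arbor's implementation: `group_relative_advantages` is a hypothetical helper and the rewards are made up.\n",
    "\n",
    "```python\n",
    "from statistics import mean\n",
    "\n",
    "\n",
    "def group_relative_advantages(rewards):\n",
    "    # Subtract the group mean so each sample is scored against its siblings.\n",
    "    baseline = mean(rewards)\n",
    "    return [r - baseline for r in rewards]\n",
    "\n",
    "\n",
    "# Made-up overall scores for 8 rollouts of the program on one input.\n",
    "rewards = [1.0, 0.5, 0.5, 1.0, 0.0, 0.5, 1.0, 0.5]\n",
    "advantages = group_relative_advantages(rewards)\n",
    "print(advantages)\n",
    "\n",
    "# Assuming each of the 4 examples per step is sampled 8 times.\n",
    "rollouts_per_step = 4 * 8\n",
    "print(rollouts_per_step)  # 32\n",
    "```\n",
    "\n",
    "Rollouts that score above their group's mean get a positive advantage and the rest a negative one, so the advantages within a group sum to zero."
   ]
  },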
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, you can use the GRPO'ed program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = devset[0]\n",
    "optimized_papillon(**example.inputs())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In our preliminary experiments, training three hours boosts the composite score (devset) from 54.6% to 60.0%. This is _typically_ worse on cost/quality basis than you'd get from running prompt optimizers like dspy.MIPROv2 or dspy.SIMBA, but it's still a very solid start for online RL over arbitrary LM programs for tiny LMs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jun2024_py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
